'''
This is pseudo-code to help you understand the paper.
The complete source code is planned to be released to the public.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from .module import *
import torch.nn.utils.spectral_norm as sn
from .conv import *
from utils.utils import *
import numpy as np


class BVAETTS(nn.Module):
    """Training-time model: encodes text and mel-spectrogram, aligns them
    with attention, and returns the training losses.

    The sub-modules (``Prenet``, ``TextEnc``, ``BVAE_Layer``, ``Conv1d``,
    ``Linear``, ``Projection``, ``DurationPredictor``) and the helpers
    (``get_mask_from_lengths``, ``PositionalEncoding``, ``Jitter``) come from
    the star imports at the top of the file; their exact contracts are not
    visible here, so shape notes below are assumptions to be confirmed.
    """

    def __init__(self, hp):
        """Build the network.

        Args:
            hp: hyper-parameter object; this class reads ``hp.downsample``,
                ``hp.hidden_dim`` and ``hp.n_mel_channels`` and passes ``hp``
                on to ``Prenet``, ``TextEnc`` and ``DurationPredictor``.
        """
        super().__init__()
        self.hp = hp
        self.downsample = hp.downsample
        self.diag_mask = None

        # build network
        self.Prenet = Prenet(hp)
        self.TextEnc = TextEnc(hp)
        # Twelve layers in four groups of three (dilations 1, 2, 4). The
        # first layer of each group carries an 'F' or 'T' tag; the meaning of
        # the tag is defined by BVAE_Layer (not visible in this file).
        self.Layers = nn.ModuleList([BVAE_Layer(hp.hidden_dim//2, 'F', dilation=2**0),
                                     BVAE_Layer(hp.hidden_dim//2, None, dilation=2**1),
                                     BVAE_Layer(hp.hidden_dim//2, None, dilation=2**2),
                                     BVAE_Layer(hp.hidden_dim//2, 'T', dilation=2**0),
                                     BVAE_Layer(hp.hidden_dim//2, None, dilation=2**1),
                                     BVAE_Layer(hp.hidden_dim//2, None, dilation=2**2),
                                     BVAE_Layer(hp.hidden_dim//4, 'F', dilation=2**0),
                                     BVAE_Layer(hp.hidden_dim//4, None, dilation=2**1),
                                     BVAE_Layer(hp.hidden_dim//4, None, dilation=2**2),
                                     BVAE_Layer(hp.hidden_dim//4, 'T', dilation=2**0),
                                     BVAE_Layer(hp.hidden_dim//4, None, dilation=2**1),
                                     BVAE_Layer(hp.hidden_dim//4, None, dilation=2**2)])
        self.Query = Conv1d(hp.hidden_dim//4, hp.hidden_dim, 5, bias=False)
        self.Compress = Linear(hp.hidden_dim, hp.hidden_dim//4, bias=False)
        self.Projection = Projection(hp.hidden_dim, hp.n_mel_channels)

        # duration predictor
        self.Duration = DurationPredictor(hp)

    def forward(self, text, melspec, text_lengths, mel_lengths):
        """Run one training step and return the individual losses.

        Args:
            text: batch of text inputs, indexed as ``[batch, position]``.
            melspec: batch of mel-spectrograms with time on the last of its
                three axes.
            text_lengths: 1-D tensor with the valid length of each text.
            mel_lengths: 1-D tensor with the valid length of each
                spectrogram.

        Returns:
            Tuple ``(recon_loss, kl_loss, duration_loss, align_loss)``.

        Side effects:
            Stores ``self.text_mask``, ``self.mel_mask`` and
            ``self.diag_mask`` for the current batch.
        """
        # Drop padding beyond the longest item of the batch.
        text = text[:, :text_lengths.max().item()]
        melspec = melspec[:, :, :mel_lengths.max().item()]
        L, T = text.size(1), melspec.size(2)
        # The masks are used below as "True = padding" (they are passed to
        # masked_fill to zero entries and negated to select valid ones).
        self.text_mask = get_mask_from_lengths(text_lengths)
        self.mel_mask = get_mask_from_lengths(mel_lengths)

        # Relative position of every text token and every down-sampled mel
        # frame inside its own sequence. The index tensors are created on the
        # device of the length tensors so the division also works off-CPU.
        text_pos = torch.arange(L, device=text_lengths.device, dtype=torch.float)
        frame_pos = torch.arange(T//self.downsample, device=mel_lengths.device, dtype=torch.float)
        offset = (text_pos.unsqueeze(0)/text_lengths.unsqueeze(1)).unsqueeze(1)\
            - (frame_pos.unsqueeze(0)/(mel_lengths//self.downsample).unsqueeze(1)).unsqueeze(2)
        # Gaussian-shaped weight that is 1 where both relative positions
        # agree and decays with their squared difference.
        diag_mask = (-12.5*torch.pow(offset, 2)).exp()
        diag_mask = diag_mask.masked_fill(self.text_mask.unsqueeze(1), 0)
        # NOTE(review): the strided mel mask has ceil(T/downsample) columns
        # while the grid has T//downsample rows; they only agree when T is a
        # multiple of downsample - presumably guaranteed upstream, confirm.
        diag_mask = diag_mask.masked_fill(self.mel_mask[:, ::self.downsample].unsqueeze(-1), 0)
        self.diag_mask = diag_mask.detach()

        ##### Text #####
        key, value = self.TextEnc(text)

        ##### Audio #####
        x = self.Prenet(melspec)
        for layer in self.Layers:
            x = layer.up(x)

        pe_q = PositionalEncoding(self.hp.hidden_dim, mel_lengths/self.downsample)
        pe_k = PositionalEncoding(self.hp.hidden_dim, text_lengths, w_s=mel_lengths/self.downsample/text_lengths)
        query = self.Query(x).transpose(1, 2) + pe_q
        key = key + pe_k
        h_a, align = self.get_align(query, key, value, mel_lengths, self.text_mask, self.mel_mask, self.diag_mask)

        # Traverse the same layers in reverse order, accumulating the KL term
        # that each layer reports.
        kl_loss = 0
        for layer in reversed(self.Layers):
            h_a, curr_kl = layer.down(h_a)
            kl_loss += curr_kl

        mel_pred = self.Projection(h_a)
        duration_out = self.get_duration(value)
        recon_loss, duration_loss, align_loss = self.compute_loss(mel_pred,
                                                                  melspec,
                                                                  duration_out,
                                                                  align,
                                                                  text_lengths,
                                                                  mel_lengths,
                                                                  self.text_mask,
                                                                  self.mel_mask,
                                                                  self.diag_mask)

        return recon_loss, kl_loss, duration_loss, align_loss

    def get_align(self, q, k, v, mel_lengths, text_mask, mel_mask, diag_mask):
        """Scaled dot-product attention with a hard (one-hot) forward pass.

        Args:
            q: queries, ``(batch, frames, hidden)``.
            k: keys, ``(batch, tokens, hidden)``.
            v: values, ``(batch, tokens, hidden)``.
            mel_lengths: only used in training mode, forwarded to ``Jitter``.
            text_mask: boolean ``(batch, tokens)``; True entries are excluded
                from the softmax.
            mel_mask: unused; kept for interface compatibility.
            diag_mask: unused; kept for interface compatibility.

        Returns:
            ``(attn_output, align)`` where ``align`` is the soft attention
            and ``attn_output`` is ``self.Compress`` applied to the attended
            values, with its last two axes swapped.
        """
        q = q * self.hp.hidden_dim ** -0.5
        scores = torch.bmm(q, k.transpose(1, 2))
        scores = scores.masked_fill(text_mask.unsqueeze(1), -float('inf'))

        align = scores.softmax(-1)
        align_oh = F.one_hot(align.argmax(dim=-1), align.size(-1))
        if self.training:
            align_oh = Jitter(align_oh, mel_lengths)

        # Straight-through trick: the forward value equals the one-hot
        # alignment, while gradients flow through the soft alignment only.
        attn_output = torch.bmm(align + (align_oh-align).detach(), v)
        attn_output = self.Compress(attn_output).transpose(1, 2)

        return attn_output, align

    def compute_loss(self, mel_pred, mel_target, duration_out, align, text_lengths, mel_lengths, text_mask, mel_mask, diag_mask):
        """Compute reconstruction, duration and attention-guide losses.

        ``text_lengths`` is accepted for interface compatibility but not
        used. Masks are treated as "True = padding".

        Returns:
            Tuple ``(recon_loss, duration_loss, align_loss)``.
        """
        # Recon Loss: L1 over the non-padded frames only.
        valid_frames = ~mel_mask.unsqueeze(1)
        recon_loss = F.l1_loss(mel_pred.masked_select(valid_frames),
                               mel_target.masked_select(valid_frames))

        # Duration Loss: MSE in the log domain against the durations read
        # off the alignment; targets of zero are raised to one first.
        duration_target = self.align2duration(align, mel_lengths)
        duration_target_flat = duration_target.masked_select(~text_mask)
        duration_target_flat[duration_target_flat <= 0] = 1
        duration_out_flat = duration_out.masked_select(~text_mask)
        duration_loss = F.mse_loss(torch.log(duration_out_flat+1e-5), torch.log(duration_target_flat+1e-5))

        # Guide Loss: attention mass weighted by its distance from the
        # diagonal, averaged over entries where diag_mask is non-zero.
        align_losses = align*(1-diag_mask)
        align_loss = torch.mean(align_losses.masked_select(diag_mask.bool()))

        return recon_loss, duration_loss, align_loss

    def get_duration(self, value):
        """Predict durations from the text values.

        ``value`` is detached first, so the duration predictor does not
        back-propagate into whatever produced it.
        """
        durations = self.Duration(value.transpose(1, 2).detach())
        return durations

    def align2duration(self, alignments, mel_lengths):
        """Turn an alignment into per-token durations.

        For every row along axis 1 the arg-max over axis 2 is taken; rows
        flagged by the mask built from ``mel_lengths//self.downsample`` are
        ignored, and the number of rows assigned to each axis-2 index is
        returned as a float tensor of shape ``(batch, alignments.size(2))``.
        """
        max_ids = alignments.argmax(dim=2)
        max_ids_oh = F.one_hot(max_ids, alignments.size(2))
        mask = get_mask_from_lengths(mel_lengths//self.downsample).unsqueeze(-1)
        max_ids_oh.masked_fill_(mask, 0)
        durations = max_ids_oh.sum(dim=1).to(torch.float)
        return durations
    
    
    
class Generator(nn.Module):
    """Inference-time model: text in, mel-spectrogram and durations out.

    It uses the same attribute names as ``BVAETTS`` for the parts needed at
    synthesis time (``TextEnc``, ``Layers``, ``Compress``, ``Projection``,
    ``Duration``) - presumably so trained weights can be loaded into it;
    confirm against the training/inference scripts. ``TextEnc``, ``TopDown``,
    ``Linear``, ``Projection`` and ``DurationPredictor`` come from the star
    imports at the top of the file.
    """

    def __init__(self, hp):
        """Build the network.

        Args:
            hp: hyper-parameter object; this class reads ``hp.hidden_dim``
                and ``hp.n_mel_channels`` and passes ``hp`` on to ``TextEnc``
                and ``DurationPredictor``.
        """
        super().__init__()
        self.hp = hp
        self.text_mask = None
        self.mel_mask = None

        # build network
        self.TextEnc = TextEnc(hp)
        # Twelve layers in four groups of three (dilations 1, 2, 4). The
        # first layer of each group carries an 'F' or 'T' tag; the meaning of
        # the tag is defined by TopDown (not visible in this file).
        self.Layers = nn.ModuleList([TopDown(hp.hidden_dim//2, 'F', dilation=2**0),
                                     TopDown(hp.hidden_dim//2, None, dilation=2**1),
                                     TopDown(hp.hidden_dim//2, None, dilation=2**2),
                                     TopDown(hp.hidden_dim//2, 'T', dilation=2**0),
                                     TopDown(hp.hidden_dim//2, None, dilation=2**1),
                                     TopDown(hp.hidden_dim//2, None, dilation=2**2),
                                     TopDown(hp.hidden_dim//4, 'F', dilation=2**0),
                                     TopDown(hp.hidden_dim//4, None, dilation=2**1),
                                     TopDown(hp.hidden_dim//4, None, dilation=2**2),
                                     TopDown(hp.hidden_dim//4, 'T', dilation=2**0),
                                     TopDown(hp.hidden_dim//4, None, dilation=2**1),
                                     TopDown(hp.hidden_dim//4, None, dilation=2**2)])
        self.Compress = Linear(hp.hidden_dim, hp.hidden_dim//4, bias=False)
        self.Projection = Projection(hp.hidden_dim, hp.n_mel_channels)

        # duration predictor
        self.Duration = DurationPredictor(hp)

    def inference(self, text, alpha=1.0, temperature=1.0):
        """Synthesize a mel-spectrogram for a single sentence.

        Args:
            text: batch of exactly one text input.
            alpha: multiplier applied to the predicted durations before
                rounding (see ``LengthRegulator``).
            temperature: forwarded unchanged to every ``layer.down`` call.

        Returns:
            ``(mel_out, durations)``: the output of ``self.Projection`` and
            the integer durations that were used for the expansion.

        Raises:
            ValueError: if ``text`` holds more than one sentence. The
                expansion below uses the first item's durations for the
                whole batch, so larger batches would be silently wrong.
        """
        # Raised explicitly rather than asserted so the check also holds
        # when Python runs with assertions disabled (-O).
        if len(text) != 1:
            raise ValueError('You must encode only one sentence at once')

        _, value = self.TextEnc(text)
        durations = self.get_duration(value)
        h_a, durations = self.LengthRegulator(value, durations, alpha)
        h_a = self.Compress(h_a).transpose(1, 2)

        # Walk the layer stack from last to first; the second return value
        # of layer.down is not needed here.
        for layer in reversed(self.Layers):
            h_a, _ = layer.down(h_a, temperature=temperature)

        mel_out = self.Projection(h_a)

        return mel_out, durations

    def get_duration(self, value, mask=None):
        """Predict durations from the text values.

        ``value`` is detached and its axes 1 and 2 are swapped before it is
        handed, together with ``mask``, to ``self.Duration``.
        """
        durations = self.Duration(value.transpose(1, 2).detach(), mask)
        return durations

    def LengthRegulator(self, hidden_states, durations, alpha=1.0):
        """Expand ``hidden_states`` along axis 1 according to ``durations``.

        Args:
            hidden_states: tensor whose axis 1 is repeated.
            durations: per-position durations; only row 0 is used for the
                expansion.
            alpha: multiplier applied to ``durations`` before rounding.

        Returns:
            ``(expanded, durations)`` where ``durations`` is the rounded
            integer tensor with every entry at least 1, and ``expanded``
            repeats position ``i`` of ``hidden_states`` ``durations[0][i]``
            times.
        """
        durations = torch.round(durations*alpha).to(torch.long)
        # Every position must be emitted at least once.
        durations[durations <= 0] = 1
        return hidden_states.repeat_interleave(durations[0], dim=1), durations